package org.bsc.langgraph4j.serializer.std;

import lombok.NonNull;
import org.bsc.langgraph4j.serializer.Serializer;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.*;
import java.util.stream.Collectors;

import static java.lang.String.format;

/**
 * Registry that maps Java types to the {@link Serializer} used to write and read them.
 * <p>
 * A lookup by {@link Class} first tries an exact match on the class name. If there is none, it
 * falls back to the first registered type that the requested class is assignable to (a superclass
 * or implemented interface). Entries are kept in registration order, so this fallback is
 * deterministic.
 * <p>
 * NOTE(review): the backing map is not synchronized; presumably serializers are registered during
 * setup, before concurrent use — confirm with callers.
 */
public class SerializerMapper {
    /**
     * Fallback serializer that delegates directly to {@link ObjectOutput#writeObject(Object)} and
     * {@link ObjectInput#readObject()}.
     */
    static final Serializer<Object> DEFAULT_SERIALIZER = new Serializer<Object>() {
        @Override
        public void write(Object object, ObjectOutput out) throws IOException {
            out.writeObject(object);
        }

        @Override
        public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
            return in.readObject();
        }
    };

    /**
     * Map key that identifies a registered type by its fully qualified class name.
     * <p>
     * Equality and hashing use only the class name. A key built from a name ({@link #of(String)})
     * therefore matches a key built from the corresponding {@link Class} ({@link #of(Class)}).
     */
    static class Key {
        private final String _className;
        private final Class<?> _clazz;

        public static Key of(Class<?> clazz) {
            return new Key(clazz);
        }
        public static Key of(String  className ) {
            return new Key(className);
        }

        private Key(Class<?> clazz) {
            _className = clazz.getName();
            _clazz = clazz;
        }

        private Key(String className) {
            _className = className;
            _clazz = null;
        }

        /** @return the fully qualified class name this key stands for */
        String getTypeName() { return _className; }

        /** @return the {@link Class} this key was built from, or {@code null} if it was built from a name only */
        Class<?> getType() { return _clazz; }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Key)) {
                return false;
            }
            // Only Key instances are compared. Comparing against a raw String would make equals
            // asymmetric (key.equals("x") true, "x".equals(key) false).
            return Objects.equals(_className, ((Key) o)._className);
        }

        @Override
        public int hashCode() {
            return Objects.hash(_className);
        }
    }

    // LinkedHashMap keeps registration order, which makes the supertype fallback in
    // getSerializer(Class) pick the first-registered match instead of an arbitrary one.
    private final Map<Key, Serializer<?>> _serializers = new LinkedHashMap<>();

    /**
     * Registers the serializer for the given type, replacing any serializer previously registered
     * for the same class name.
     *
     * @param clazz      the type whose instances {@code serializer} handles
     * @param serializer the serializer to use for {@code clazz}
     * @return this mapper, for chaining
     */
    public SerializerMapper register(@NonNull  Class<?> clazz, @NonNull  Serializer<?> serializer ) {
        _serializers.put( Key.of(clazz), serializer);
        return this;
    }

    /**
     * Removes the serializer registered for the given type.
     *
     * @param clazz the type previously passed to {@link #register(Class, Serializer)}
     * @return {@code true} if a serializer was registered for {@code clazz} and has been removed
     */
    public boolean unregister( @NonNull  Class<?> clazz ) {
        // Entries are keyed by the serialized type (see register), not by the serializer's class.
        return _serializers.remove( Key.of(clazz) ) != null;
    }

    /**
     * Finds the serializer for the given type.
     * <p>
     * An exact match on the class name wins. Otherwise the first registered type (in registration
     * order) that {@code clazz} is assignable to is used.
     *
     * @param clazz the runtime type to serialize
     * @return the matching serializer, or empty if no registered type matches
     */
    @SuppressWarnings("unchecked")
    public Optional<Serializer<Object>> getSerializer( @NonNull Class<?> clazz ) {
        final Serializer<?> exact = _serializers.get( Key.of(clazz) );
        if( exact != null ) {
            return Optional.of( (Serializer<Object>)exact );
        }

        return _serializers.entrySet().stream()
                .filter( e -> {
                    final Class<?> type = e.getKey().getType();
                    // Keys stored via register() always carry a Class; guard anyway.
                    return type != null && type.isAssignableFrom(clazz);
                })
                .findFirst()
                .map( e -> (Serializer<Object>)e.getValue() );
    }

    /**
     * Finds the serializer registered for exactly the given class name. No supertype fallback is
     * applied.
     *
     * @param className fully qualified class name
     * @return the registered serializer, or empty if none is registered under that name
     */
    @SuppressWarnings("unchecked")
    public Optional<Serializer<Object>> getSerializer( @NonNull String className ) {
        return Optional.ofNullable((Serializer<Object>)_serializers.get( Key.of(className) ));
    }

    /** @return the fallback serializer that uses plain {@code writeObject}/{@code readObject} */
    public Serializer<Object> getDefaultSerializer() {
        return DEFAULT_SERIALIZER;
    }

    /**
     * Returns {@code out} unchanged if it is already an {@code ObjectOutputWithMapper}. Otherwise
     * wraps it in a new {@code ObjectOutputWithMapper} bound to this mapper.
     */
    protected final ObjectOutput objectOutputWithMapper(@NonNull  ObjectOutput out) {

        final ObjectOutputWithMapper mapperOut ;
        if( out instanceof ObjectOutputWithMapper ) {
            mapperOut = (ObjectOutputWithMapper)out;
        } else {
            mapperOut = new ObjectOutputWithMapper( out, this );
        }

        return mapperOut;
    }

    /**
     * Returns {@code in} unchanged if it is already an {@code ObjectInputWithMapper}. Otherwise
     * wraps it in a new {@code ObjectInputWithMapper} bound to this mapper.
     */
    protected final ObjectInput objectInputWithMapper(@NonNull  ObjectInput in) {

        final ObjectInputWithMapper mapperIn ;
        if( in instanceof ObjectInputWithMapper ) {
            mapperIn = (ObjectInputWithMapper)in;
        } else {
            mapperIn = new ObjectInputWithMapper( in, this );
        }

        return mapperIn;

    }

    /** Lists the registered type names, one per line, in registration order. */
    @Override
    public String toString() {
        final String typeNames = _serializers.keySet().stream()
                .map(Key::getTypeName)
                .collect(Collectors.joining("\n"));
        return format( "SerializerMapper: \n%s", typeNames );
    }

}
